SQL Jupyter Notebook

OpenAI Setup

In [45]:
Code
import openai
import os

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

openai.api_key = os.getenv('ENV_OPENAI_API_KEY')
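
If the `.env` file is missing or the variable name is misspelled, `os.getenv` returns `None` without complaint, and the problem only surfaces at the first API call. A minimal sketch of an early check follows. The helper `require_env` and the variable `DEMO_API_KEY` are hypothetical names used for illustration; they are not part of `openai` or `python-dotenv`.

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`; raise if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name!r} is not set")
    return value


# Demo with a placeholder variable (hypothetical); in the notebook the real
# variable would be loaded from the .env file by load_dotenv().
os.environ["DEMO_API_KEY"] = "demo-value"
api_key = require_env("DEMO_API_KEY")
```

Failing fast here keeps a configuration mistake from being confused with an API error later in the notebook.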

Database

Code
response = openai.ChatCompletion.create(
  model="gpt-3.5-turbo",
  messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the super bowl in 1973?"},
        {"role": "assistant", "content": "The Miami Dolphins won the Super Bowl in 1973."},
        {"role": "user", "content": "Where was it played?"}
    ]
)
In [6]:
Code
print(response['choices'][0]['message']['content'])
The Super Bowl in 1973, which was Super Bowl VII, was played at the Los Angeles Memorial Coliseum in Los Angeles, California.
In [8]:
Code
response
<OpenAIObject chat.completion id=chatcmpl-7DMvdzWvhcClj1VZI4tMpcyuSLupN at 0x1f3f7a9d720> JSON: {
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "content": "The Super Bowl in 1973, which was Super Bowl VII, was played at the Los Angeles Memorial Coliseum in Los Angeles, California.",
        "role": "assistant"
      }
    }
  ],
  "created": 1683420757,
  "id": "chatcmpl-7DMvdzWvhcClj1VZI4tMpcyuSLupN",
  "model": "gpt-3.5-turbo-0301",
  "object": "chat.completion",
  "usage": {
    "completion_tokens": 29,
    "prompt_tokens": 56,
    "total_tokens": 85
  }
}
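
The response above carries more than the answer text: `finish_reason` says why generation ended, and `usage` reports token counts. A small sketch of pulling those fields out is shown here. It uses a plain dict that mirrors the printed response rather than a live API object, and `summarize_usage` is a hypothetical helper name.

```python
# Plain dict standing in for the response printed above (values copied from it).
response_like = {
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {
                "role": "assistant",
                "content": (
                    "The Super Bowl in 1973, which was Super Bowl VII, was played at "
                    "the Los Angeles Memorial Coliseum in Los Angeles, California."
                ),
            },
        }
    ],
    "usage": {"completion_tokens": 29, "prompt_tokens": 56, "total_tokens": 85},
}


def summarize_usage(resp):
    """Collect the finish reason and token counts from a chat-completion-shaped dict."""
    usage = resp["usage"]
    return {
        "finish_reason": resp["choices"][0]["finish_reason"],
        "prompt_tokens": usage["prompt_tokens"],
        "completion_tokens": usage["completion_tokens"],
        "total_tokens": usage["total_tokens"],
    }


summary = summarize_usage(response_like)
print(summary)
```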
Code
# Few-shot system prompt: the model is told to answer with SQL only.
system_prompt = """
    You are the world's best SQL expert. Help me convert natural language to valid SQL queries. Only respond with valid SQL queries, nothing else.
    You must learn the column names based on the information the user gives you and build valid SQL queries. Never guess the column names.
    These are the examples:

    query: get all people names
    answer: SELECT name from people;

    query: get all cars whose owner name is aaron
    answer: SELECT c.* FROM people p JOIN cars c ON p.id = c.owner_id WHERE p.name = 'aaron';
"""


def text_to_sql(q, db_schema_info, model="gpt-3.5-turbo"):
    """Convert a natural-language question into SQL via the ChatCompletion API."""
    # The user message carries the schema and the question to convert.
    user_prompt = f"""
    This is my database information:
    {db_schema_info}

    query: {q}
    answer:
"""

    completion = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    return completion.choices[0].message.content
In [11]:
Code
# Caution: stop=["\n"] ends generation at the first newline. Completions from this
# model tend to begin with "\n\n", so the returned text can come back empty.
def generate_sql_query(prompt, max_tokens=100, temperature=0, top_p=1.0, frequency_penalty=0.0, presence_penalty=0.0, stop=["\n"]):
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stop=stop
    )
    return response.choices[0].text
In [12]:
Code
generate_sql_query("Get all the users that are older than 25 years old")
''
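
The empty string above is consistent with how a stop sequence behaves: output is cut at the first occurrence of the stop text. The sketch below is a local simulation of that truncation, not the API's actual implementation. `apply_stop` is a hypothetical helper, and `raw` is an assumed completion that begins with two newlines.

```python
def apply_stop(text, stop):
    """Cut `text` at the earliest occurrence of any stop sequence (the stop text is dropped)."""
    cut = len(text)
    for s in stop:
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


# Assumed raw completion: two leading newlines followed by the query.
raw = "\n\nSELECT * FROM users WHERE age > 25;"

empty_result = apply_stop(raw, ["\n"])  # the first character is already a stop sequence
kept_result = apply_stop(raw, [";"])    # cut at the semicolon instead

print(repr(empty_result))
print(repr(kept_result))
```

Under this reading, dropping the `stop` argument or choosing a stop sequence that cannot appear at the start of the completion avoids the empty result.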

SQL Generation Function

Script

In [49]:
Code
prompt = "Get all the users that are older than 35 years old"
model = "text-davinci-002"
temperature = 0.0
max_tokens = 50

response = openai.Completion.create(
    engine=model,
    prompt=prompt,
    temperature=temperature,
    max_tokens=max_tokens,
)

print(response.choices[0].text)


SELECT * FROM users WHERE age > 35;
In [48]:
Code
# Korean prompt: "Write SQL code that extracts users aged 35 or older from the users table"
prompt = "users 테이블에서 35세 이상 사용자를 추출하는 SQL 코드를 작성하세요"
model = "text-davinci-002"
temperature = 0.5
max_tokens = 50

response = openai.Completion.create(
    engine=model,
    prompt=prompt,
    temperature=temperature,
    max_tokens=max_tokens,
)

print(response.choices[0].text)
.

```mysql
SELECT * FROM users WHERE age >= 35;
```
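
The raw completion above is not directly executable: it starts with a stray period and wraps the query in a fenced block. A minimal cleanup sketch, assuming the output takes one of the shapes seen in this notebook (a fenced block, or bare SQL with leading newlines or a period), might look like this. `extract_sql` is a hypothetical helper, not part of the `openai` package.

```python
import re

# Matches a fenced block with an optional language tag and captures its body.
FENCED_BLOCK = re.compile(r"`{3}[A-Za-z]*\s*\n(.*?)`{3}", re.DOTALL)


def extract_sql(raw: str) -> str:
    """Return the SQL inside a fenced block if present, else the text with leading debris removed."""
    match = FENCED_BLOCK.search(raw)
    if match:
        return match.group(1).strip()
    # No fence: drop leading whitespace and stray sentence punctuation.
    return raw.lstrip(" \t\r\n.").strip()


fence = "`" * 3
fenced_output = ".\n\n" + fence + "mysql\nSELECT * FROM users WHERE age >= 35;\n" + fence + "\n"
bare_output = "\n\nSELECT * FROM users WHERE age > 35;"

print(extract_sql(fenced_output))
print(extract_sql(bare_output))
```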

Function

In [52]:
Code
def generate_sql_query(prompt, max_tokens = 100):
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        temperature=0,
        max_tokens=max_tokens
    )
    return response.choices[0].text

generate_sql_query("Get all the users that are older than 35 years old from users table")
'\n\nSELECT * FROM users WHERE age > 35;'

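CLI Implementation

In [18]:
Code
# Example input (Korean): "user 테이블에서 35세이상 사용자 추출하세요"
# ("Extract users aged 35 or older from the user table")
prompt = input("Enter a prompt for writing SQL: ")
print(generate_sql_query(prompt))
Enter a prompt for writing SQL:  user 테이블에서 35세이상 사용자 추출하세요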
.

SELECT * FROM user WHERE age >= 35;
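
Outside a notebook, the same idea can be wrapped in a small command-line entry point. The sketch below uses `argparse` and takes the generator as a parameter so it can be exercised without an API call. `main` and `fake_generate` are hypothetical names; in real use `generate_sql_query` from the cell above would be passed in instead.

```python
import argparse


def main(argv=None, generate=None):
    """Parse a prompt from the command line and print the generated SQL."""
    parser = argparse.ArgumentParser(description="Generate SQL from a natural-language prompt")
    parser.add_argument("prompt", help="natural-language request")
    parser.add_argument("--max-tokens", type=int, default=100)
    args = parser.parse_args(argv)

    if generate is None:
        raise SystemExit("no SQL generator configured")

    # Strip the leading newlines that completions often start with.
    sql = generate(args.prompt, max_tokens=args.max_tokens).strip()
    print(sql)
    return sql


def fake_generate(prompt, max_tokens=100):
    """Stand-in for generate_sql_query so the sketch runs offline (hypothetical output)."""
    return "\n\nSELECT * FROM users WHERE age >= 35;"


result = main(["users aged 35 or older"], generate=fake_generate)
```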

Langchain + SQL

In [53]:
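Code
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain  # import paths from an early LangChain release

# Connect to the SQLite survey database
survey_db = SQLDatabase.from_uri("sqlite:///survey.db")
llm = OpenAI(temperature=0)

# The chain turns a question into SQL, runs it, and summarizes the result
db_chain = SQLDatabaseChain.from_llm(llm, survey_db, verbose=True)
db_chain.run("How many tables are there? If there exists tables, list table names")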


> Entering new SQLDatabaseChain chain...
How many tables are there? If there exists tables, list table names
SQLQuery: SELECT name FROM sqlite_master WHERE type='table';
SQLResult: [('Person',), ('Site',), ('Survey',), ('Visited',)]
Answer: There are 4 tables: Person, Site, Survey, Visited.
> Finished chain.
' There are 4 tables: Person, Site, Survey, Visited.'
In [16]:
Code
db_chain.run("Write a query that selects distinct dates from the Visited table.")


> Entering new SQLDatabaseChain chain...
Write a query that selects distinct dates from the Visited table.
SQLQuery: SELECT DISTINCT dated FROM Visited;
SQLResult: [('1927-02-08',), ('1927-02-10',), ('1930-01-07',), ('1930-01-12',), ('1930-02-26',), (None,), ('1932-01-14',), ('1932-03-22',)]
Answer: The distinct dates from the Visited table are 1927-02-08, 1927-02-10, 1930-01-07, 1930-01-12, 1930-02-26, None, 1932-01-14, and 1932-03-22.
> Finished chain.
' The distinct dates from the Visited table are 1927-02-08, 1927-02-10, 1930-01-07, 1930-01-12, 1930-02-26, None, 1932-01-14, and 1932-03-22.'
In [55]:
Code
db_chain.run("Write a query that displays the full names of the scientists in the Person table, ordered by family name.")


> Entering new SQLDatabaseChain chain...
Write a query that displays the full names of the scientists in the Person table, ordered by family name.
SQLQuery: SELECT "personal" || ' ' || "family" AS "Full Name" FROM "Person" ORDER BY "family" ASC LIMIT 5;
SQLResult: [('Frank Danforth',), ('William Dyer',), ('Anderson Lake',), ('Frank Pabodie',), ('Valentina Roerich',)]
Answer: The full names of the scientists in the Person table, ordered by family name, are Frank Danforth, William Dyer, Anderson Lake, Frank Pabodie, and Valentina Roerich.
> Finished chain.
' The full names of the scientists in the Person table, ordered by family name, are Frank Danforth, William Dyer, Anderson Lake, Frank Pabodie, and Valentina Roerich.'
In [19]:
Code
db_chain.run('Normalized salinity readings are supposed to be between 0.0 and 1.0. \
              Write a query that selects all records from Survey with salinity values outside this range.\
              print records')


> Entering new SQLDatabaseChain chain...
Normalized salinity readings are supposed to be between 0.0 and 1.0.               Write a query that selects all records from Survey with salinity values outside this range.              print records
SQLQuery: SELECT * FROM Survey WHERE quant = 'sal' AND reading NOT BETWEEN 0.0 AND 1.0;
SQLResult: [(752, 'roe', 'sal', 41.6), (837, 'roe', 'sal', 22.5)]
Answer: There are two records with salinity values outside the range of 0.0 and 1.0: (752, 'roe', 'sal', 41.6) and (837, 'roe', 'sal', 22.5).
> Finished chain.
" There are two records with salinity values outside the range of 0.0 and 1.0: (752, 'roe', 'sal', 41.6) and (837, 'roe', 'sal', 22.5)."
In [57]:
Code
db_chain.run('After further reading, we realize that Valentina Roerich(roe) was reporting salinity as percentages.\
              Write a query that returns all of her salinity measurements from the Survey table with the values divided by 100.')


> Entering new SQLDatabaseChain chain...
After further reading, we realize that Valentina Roerich(roe) was reporting salinity as percentages.              Write a query that returns all of her salinity measurements from the Survey table with the values divided by 100.
SQLQuery: SELECT person, quant, reading/100 AS reading FROM Survey WHERE person = 'roe' AND quant = 'sal';
SQLResult: [('roe', 'sal', 0.41600000000000004), ('roe', 'sal', 0.225)]
Answer: Valentina Roerich reported salinity measurements of 0.416 and 0.225.
> Finished chain.
' Valentina Roerich reported salinity measurements of 0.416 and 0.225.'
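
The value `0.41600000000000004` in the result above is ordinary floating-point noise from the division. If tidy values are wanted in the query result itself, SQLite's `ROUND` can be applied. The sketch below rebuilds the two rows shown earlier in an in-memory database; rounding to three decimals is an arbitrary choice for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Survey (taken integer, person text, quant text, reading real)")
# The two out-of-range salinity rows reported earlier in the notebook.
conn.executemany(
    "INSERT INTO Survey VALUES (?, ?, ?, ?)",
    [(752, "roe", "sal", 41.6), (837, "roe", "sal", 22.5)],
)

rows = conn.execute(
    "SELECT taken, ROUND(reading / 100, 3) AS reading "
    "FROM Survey WHERE person = 'roe' AND quant = 'sal' ORDER BY taken"
).fetchall()
conn.close()

print(rows)
```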
In [26]:
Code
db_chain.run('How many temperature readings from quant did Frank Pabodie(pb) record, and what was their average value?')


> Entering new SQLDatabaseChain chain...
How many temperature readings from quant did Frank Pabodie(pb) record, and what was their average value?
SQLQuery: SELECT COUNT(quant), AVG(reading) FROM Survey WHERE person = 'pb' AND quant = 'temp';
SQLResult: [(2, -20.0)]
Answer: Frank Pabodie recorded 2 temperature readings, with an average value of -20.0.
> Finished chain.
' Frank Pabodie recorded 2 temperature readings, with an average value of -20.0.'
In [29]:
Code
db_chain.run('Write a query that lists all sites visited by people named "Frank". write sql step-by-step')


> Entering new SQLDatabaseChain chain...
Write a query that lists all sites visited by people named "Frank". write sql step-by-step
SQLQuery: SELECT Site.name 
FROM Site 
INNER JOIN Visited 
ON Site.name = Visited.site 
INNER JOIN Person 
ON Visited.id = Person.id 
WHERE Person.personal = "Frank"
SQLResult: []
Answer: No sites were visited by people named "Frank".
> Finished chain.
' No sites were visited by people named "Frank".'
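
The empty result above most likely comes from the join condition `Visited.id = Person.id`: according to the table definitions, `Visited.id` is an integer while `Person.id` is text, so the two never match. A sketch of an alternative join is shown below. It assumes that `Survey.taken` refers to `Visited.id` and `Survey.person` to `Person.id`, which is an assumption about this dataset. The sample rows and site names are invented for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE Person (id text, personal text, family text);
    CREATE TABLE Site (name text, lat real, long real);
    CREATE TABLE Survey (taken integer, person text, quant text, reading real);
    CREATE TABLE Visited (id integer, site text, dated text);

    -- Invented sample rows (not the real survey data).
    INSERT INTO Person VALUES ('pb', 'Frank', 'Pabodie'), ('dyer', 'William', 'Dyer');
    INSERT INTO Site VALUES ('SITE-A', 0.0, 0.0), ('SITE-B', 1.0, 1.0);
    INSERT INTO Visited VALUES (1, 'SITE-A', '1930-01-01'), (2, 'SITE-B', '1930-01-02');
    INSERT INTO Survey VALUES (1, 'pb', 'temp', -20.0), (2, 'dyer', 'rad', 1.0);
""")

# Join used by the chain: integer visit ids never equal text person ids.
generated = conn.execute("""
    SELECT Site.name FROM Site
    INNER JOIN Visited ON Site.name = Visited.site
    INNER JOIN Person ON Visited.id = Person.id
    WHERE Person.personal = 'Frank'
""").fetchall()

# Alternative: go through Survey, which links a visit (taken) to a person.
via_survey = conn.execute("""
    SELECT DISTINCT Visited.site
    FROM Person
    JOIN Survey ON Survey.person = Person.id
    JOIN Visited ON Visited.id = Survey.taken
    WHERE Person.personal = 'Frank'
    ORDER BY Visited.site
""").fetchall()
conn.close()

print(generated)
print(via_survey)
```

Giving the chain the relationships between tables, not only the column lists, is one way to steer it toward the second form.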

DB Schema

Code
db_chain.run('print database schema info')


> Entering new SQLDatabaseChain chain...
print database schema info
SQLQuery: SELECT * FROM sqlite_master;
SQLResult: [('table', 'Person', 'Person', 2, 'CREATE TABLE Person (id text, personal text, family text)'), ('table', 'Site', 'Site', 3, 'CREATE TABLE Site (name text, lat real, long real)'), ('table', 'Survey', 'Survey', 5, 'CREATE TABLE Survey (taken integer, person text, quant text, reading real)'), ('table', 'Visited', 'Visited', 4, 'CREATE TABLE Visited (id integer, site text, dated text)')]
Answer: The database schema info is: Person (id text, personal text, family text), Site (name text, lat real, long real), Survey (taken integer, person text, quant text, reading real), Visited (id integer, site text, dated text).
> Finished chain.
' The database schema info is: Person (id text, personal text, family text), Site (name text, lat real, long real), Survey (taken integer, person text, quant text, reading real), Visited (id integer, site text, dated text).'
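
The same schema summary can be read without an LLM call by querying `sqlite_master` directly. The sketch below builds the compact `Table (col type, ...)` form from the stored `CREATE TABLE` text. `compact_schema` is a hypothetical helper, and the in-memory database only reproduces the four table definitions shown above.

```python
import sqlite3


def compact_schema(conn):
    """Return one 'Table (col type, ...)' line per table, joined with commas and newlines."""
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    parts = []
    for name, sql in rows:
        # Keep only the column list between the outermost parentheses.
        columns = sql[sql.index("(") + 1 : sql.rindex(")")].strip()
        parts.append(f"{name} ({columns})")
    return ",\n".join(parts)


conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE Person (id text, personal text, family text);
    CREATE TABLE Site (name text, lat real, long real);
    CREATE TABLE Survey (taken integer, person text, quant text, reading real);
    CREATE TABLE Visited (id integer, site text, dated text);
""")
schema_text = compact_schema(conn)
conn.close()

print(schema_text)
```

The resulting string can be placed in the "This is my database information" part of a prompt instead of being typed by hand.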
Code
system_prompt = """
    You are the world's best SQL expert. Help me convert natural language to valid SQL queries. Only respond with valid SQL queries, nothing else.
    You must learn the column names based on the information the user gives you and build valid SQL queries. Never guess the column names.
    These are the examples:

    query: get all people names
    answer: SELECT name from people;

    query: get all cars whose owner name is aaron
    answer: SELECT c.* FROM people p JOIN cars c ON p.id = c.owner_id WHERE p.name = 'aaron';
"""

query = 'Write a query that lists all radiation readings from the DR-1 site step-by-step'

user_prompt = f"""
    This is my database information:
    Person (id text, personal text, family text), 
    Site (name text, lat real, long real), 
    Survey (taken integer, person text, quant text, reading real), 
    Visited (id integer, site text, dated text)

    query: {query}
    answer:
"""



completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ],
)
print(completion.choices[0].message.content)
To write a query that lists all radiation readings from the DR-1 site, we can follow these steps:

1. Determine the ID of the DR-1 site by querying the Site table. We will assume that the ID of the DR-1 site is 'DR-1'.

2. Join the Visited and Survey tables on the site column to only include records from the DR-1 site.

3. Filter the result to only show records with quant column equal to 'rad'.

4. Select the reading column to get the radiation readings.

The SQL query that accomplishes this is:

    SELECT s.reading
    FROM Visited v
    JOIN Survey s ON v.id = s.taken AND v.site = 'DR-1'
    WHERE s.quant = 'rad';
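
Before running a generated query against real data, it can be checked against the schema alone. The sketch below creates an empty in-memory database from the table definitions shown earlier and asks SQLite to compile the statement with `EXPLAIN`, which catches unknown tables and columns without touching any rows. `check_sql` is a hypothetical helper, and the SQL is assumed to be already separated from the surrounding explanation text.

```python
import sqlite3

SCHEMA = """
    CREATE TABLE Person (id text, personal text, family text);
    CREATE TABLE Site (name text, lat real, long real);
    CREATE TABLE Survey (taken integer, person text, quant text, reading real);
    CREATE TABLE Visited (id integer, site text, dated text);
"""


def check_sql(sql, schema=SCHEMA):
    """Return None if `sql` compiles against `schema`, else the SQLite error message."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)
        # EXPLAIN compiles the statement without executing it.
        conn.execute("EXPLAIN " + sql)
        return None
    except sqlite3.Error as exc:
        return str(exc)
    finally:
        conn.close()


generated_sql = (
    "SELECT s.reading FROM Visited v "
    "JOIN Survey s ON v.id = s.taken AND v.site = 'DR-1' "
    "WHERE s.quant = 'rad';"
)
ok_result = check_sql(generated_sql)
bad_result = check_sql("SELECT name FROM people;")

print(ok_result)
print(bad_result)
```

A check like this only confirms that the query is valid for the schema; whether it answers the question still has to be judged separately.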